/*
 * Copyright  2017 NXP
 * All rights reserved.
 *
 *
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include <stdio.h>
#include <string.h>

#include "fsl_common.h"
#include "board.h"
#include "pin_mux.h"

#include "timer.h"

#include "arm_math.h"
#include "arm_const_structs.h"

#include "fsl_powerquad.h"

/*******************************************************************************
 * Definitions
 ******************************************************************************/
/* Dimension of every square matrix used by the benchmarks (8x8). */
#define DCT2D_MATRIX_N  8u
/* Base address of the PowerQuad private RAM. The PowerQuad benchmarks use it as the
 * buffer for the intermediate product between two chained matrix multiplications,
 * and as pq_config.tmpBase in the int32 variant. */
#define POWERQUAD_PRIVATE_RAM_BASE  0xE0000000

/* Set to non-zero to print every matrix; 0 turns print_matrix_*() into no-ops. */
#define APP_ENABLE_PRINT_RESULT  0

/*******************************************************************************
 * Prototypes
 ******************************************************************************/
/* Initialise PowerQuad and build the transform matrix A and its inverse. */
void dct2d_powerquad_init(void);
/* Scalar math helpers backed by the PowerQuad math engine. */
float powerquad_inv(float input);
float powerquad_cos(float input);
float powerquad_sqrt(float input);

/* Row-major matrix printers; no-ops unless APP_ENABLE_PRINT_RESULT is non-zero. */
void print_matrix_f32(float *matrix, uint32_t row_num, uint32_t col_num);
void print_matrix_int32(int32_t *matrix, uint32_t row_num, uint32_t col_num);

/*******************************************************************************
 * Variables
 ******************************************************************************/
/* Elapsed ticks of the most recently timed section; passed to timer_stop() and
 * converted to microseconds with TIMER_TICKS_PER_US before printing.
 * NOTE(review): presumably timer_stop() is a macro that assigns into its
 * argument - confirm in timer.h. */
uint32_t app_timer_ticks;

/* Benchmark cases, executed in order by main() through app_funcs[]. */
void dct2d_arm_f32(void);
void dct2d_arm_int32(void);
void dct2d_powerquad_f32(void);
void dct2d_powerquad_int32(void);

/*!
 * @brief Last entry of app_funcs[]: prints the end-of-round banner.
 */
void app_func_end(void)
{
    static const char end_banner[] = "\r\n----- END -----\r\n";

    printf("%s", end_banner);
}

/* Signature shared by every benchmark entry: no arguments, no result. */
typedef void (* func_0_t)(void);
/* Benchmark sequence. main() walks this table forever, waiting for one console
 * character before running each entry. */
func_0_t app_funcs[] =
{
    dct2d_arm_f32,          /* CMSIS-DSP, float32 */
    dct2d_arm_int32,        /* CMSIS-DSP, q31 */
    dct2d_powerquad_f32,    /* PowerQuad, float32 */
    dct2d_powerquad_int32,  /* PowerQuad, int32 input/output */
    app_func_end            /* end-of-round banner */
};

/* Scratch buffer passed to PQ_MatrixInversion() when inverting A. */
uint32_t powerquad_matrix_inv_tmp[1024];
/* float32 working set: transform matrix A, its inverse, and the input, output
 * and intermediate (A * X or A_inv * Y) matrices of the 2D transform. */
float dct2d_A_matrix[DCT2D_MATRIX_N][DCT2D_MATRIX_N];
float dct2d_A_inv_matrix[DCT2D_MATRIX_N][DCT2D_MATRIX_N];
float dct2d_input_matrix[DCT2D_MATRIX_N][DCT2D_MATRIX_N];
float dct2d_output_matrix[DCT2D_MATRIX_N][DCT2D_MATRIX_N];
float dct2d_tmp_matrix[DCT2D_MATRIX_N][DCT2D_MATRIX_N];

/* 32-bit integer / q31 counterparts of the matrices above, used by the
 * fixed-point and int32 benchmark variants. */
int32_t dct2d_A_matrix_int32[DCT2D_MATRIX_N][DCT2D_MATRIX_N];
int32_t dct2d_A_inv_matrix_int32[DCT2D_MATRIX_N][DCT2D_MATRIX_N];
int32_t dct2d_input_matrix_int32[DCT2D_MATRIX_N][DCT2D_MATRIX_N];
int32_t dct2d_output_matrix_int32[DCT2D_MATRIX_N][DCT2D_MATRIX_N];
int32_t dct2d_tmp_matrix_int32[DCT2D_MATRIX_N][DCT2D_MATRIX_N];

/*******************************************************************************
 * Code
 ******************************************************************************/


/*!
 * @brief Main function
 */
int main(void)
{
    board_init();
    printf("powerquad dct 2d benchmark.\r\n");

    /* Benchmark timer (systick). */
    timer_init();

    /* Start PowerQuad and build the transform matrices used by every benchmark. */
    dct2d_powerquad_init();

    while (1)
    {
        for (uint32_t i = 0u; i < sizeof(app_funcs) / sizeof(app_funcs[0]); i++)
        {
            /* getchar() returns int so EOF stays distinguishable from a valid byte;
             * storing it in a char would truncate EOF and echo a bogus character. */
            int ch = getchar();
            if (ch != EOF)
            {
                putchar(ch);
            }
            app_funcs[i]();
        }
    }
}

/*!
 * @brief Initialise PowerQuad and build the DCT-II transform matrix A and its inverse.
 *
 * Fills dct2d_A_matrix (N = DCT2D_MATRIX_N) with
 *   A[0][j] = sqrt(1/N)
 *   A[i][j] = sqrt(2/N) * cos(i * (2j + 1) * pi / (2N)),   i > 0
 * using the PowerQuad scalar math helpers. It then inverts A with the PowerQuad
 * matrix engine into dct2d_A_inv_matrix. The time taken by both steps is printed.
 */
void dct2d_powerquad_init(void)
{
    printf("\r\n%s\r\n", __func__);
    /* powerquad. */
    PQ_Init(POWERQUAD);

    uint32_t i, j;
    float factor1, factor2;
    uint32_t length = POWERQUAD_MAKE_MATRIX_LEN(DCT2D_MATRIX_N, DCT2D_MATRIX_N, DCT2D_MATRIX_N);

    timer_start();

    /* Build the transform matrix A. factor2 = 1/N. */
    factor2 = powerquad_inv((float)DCT2D_MATRIX_N);

    /* Row 0: constant sqrt(1/N). */
    factor1 = powerquad_sqrt(factor2);
    for (j = 0u; j < DCT2D_MATRIX_N; j++)
    {
        dct2d_A_matrix[0][j] = factor1;
    }

    /* Rows 1..N-1: sqrt(2/N) * cos(i * (2j + 1) * pi / (2N)). */
    factor1 = powerquad_sqrt(2.0f * factor2);
    for (i = 1u; i < DCT2D_MATRIX_N; i++)
    {
        for (j = 0u; j < DCT2D_MATRIX_N; j++)
        {
            dct2d_A_matrix[i][j] = factor1 * powerquad_cos((float)(i * (2u * j + 1u)) * (0.5f * PI * factor2));
        }
    }

    /* Invert A on the PowerQuad matrix engine in float format. */
    PQ_SetFormat(POWERQUAD, kPQ_CP_MTX, kPQ_Float);
    PQ_MatrixInversion(POWERQUAD, length, dct2d_A_matrix, powerquad_matrix_inv_tmp, dct2d_A_inv_matrix);
    PQ_WaitDone(POWERQUAD);

    timer_stop(app_timer_ticks);

    printf("matrix_A:\r\n");
    print_matrix_f32((float *)dct2d_A_matrix, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    printf("\r\n");

    printf("matrix_A_inv:\r\n");
    print_matrix_f32((float *)dct2d_A_inv_matrix, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    printf("\r\n");

    /* Cast so the argument matches the conversion regardless of how uint32_t
     * and TIMER_TICKS_PER_US are typed on this toolchain. */
    printf(" * %lu us\r\n", (unsigned long)(app_timer_ticks / TIMER_TICKS_PER_US));
}

/* 使用cmsis-dsp的api进行计算, 浮点数实现. */
/*!
 * @brief Benchmark the 2D DCT with the CMSIS-DSP float32 matrix API.
 *
 * Builds the test input X[i][j] = 10*i + j, then times:
 *   forward:  Y = A * X * A_inv   (result in dct2d_output_matrix)
 *   backward: X = A_inv * Y * A   (result written back to dct2d_input_matrix)
 * dct2d_A_matrix / dct2d_A_inv_matrix are expected to have been filled by
 * dct2d_powerquad_init() beforehand.
 */
void dct2d_arm_f32(void)
{
    printf("\r\n%s().\r\n", __func__);

    arm_matrix_instance_f32 dct2d_A_matrix_instance =
    {
        .pData = (float *)dct2d_A_matrix,
        .numCols = DCT2D_MATRIX_N,
        .numRows = DCT2D_MATRIX_N
    };

    arm_matrix_instance_f32 dct2d_A_inv_matrix_instance =
    {
        .pData = (float *)dct2d_A_inv_matrix,
        .numCols = DCT2D_MATRIX_N,
        .numRows = DCT2D_MATRIX_N
    };

    arm_matrix_instance_f32 dct2d_input_matrix_instance =
    {
        .pData = (float *)dct2d_input_matrix,
        .numCols = DCT2D_MATRIX_N,
        .numRows = DCT2D_MATRIX_N
    };
    arm_matrix_instance_f32 dct2d_output_matrix_instance =
    {
        .pData = (float *)dct2d_output_matrix,
        .numCols = DCT2D_MATRIX_N,
        .numRows = DCT2D_MATRIX_N
    };
    arm_matrix_instance_f32 dct2d_tmp_matrix_instance =
    {
        .pData = (float *)dct2d_tmp_matrix,
        .numCols = DCT2D_MATRIX_N,
        .numRows = DCT2D_MATRIX_N
    };

    /* Generate the test input matrix X[i][j] = 10*i + j. */
    uint32_t i, j;
    for (i = 0u; i < DCT2D_MATRIX_N; i++)
    {
        for (j = 0u; j < DCT2D_MATRIX_N; j++)
        {
            dct2d_input_matrix[i][j] = i * 10.0f + j;
        }
    }

    printf("input_matrix_f32:\r\n");
    print_matrix_f32((float *)dct2d_input_matrix, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    printf("\r\n");

    /* DCT forward transform: Y = A * X * A_inv. */
    memset(dct2d_output_matrix, 0, sizeof(dct2d_output_matrix));
    memset(dct2d_tmp_matrix, 0, sizeof(dct2d_tmp_matrix));

    timer_start();

    arm_mat_mult_f32(&dct2d_A_matrix_instance, &dct2d_input_matrix_instance, &dct2d_tmp_matrix_instance);
    arm_mat_mult_f32(&dct2d_tmp_matrix_instance, &dct2d_A_inv_matrix_instance, &dct2d_output_matrix_instance);

    timer_stop(app_timer_ticks);

    printf("output_matrix_f32:\r\n");
    print_matrix_f32((float *)dct2d_output_matrix, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    printf("\r\n");
    printf(" * %lu us\r\n", (unsigned long)(app_timer_ticks / TIMER_TICKS_PER_US));

    /* DCT backward (inverse) transform: X = A_inv * Y * A. It should reproduce
     * the original input up to float rounding. */
    memset(dct2d_input_matrix, 0, sizeof(dct2d_input_matrix));
    memset(dct2d_tmp_matrix, 0, sizeof(dct2d_tmp_matrix));

    timer_start();

    arm_mat_mult_f32(&dct2d_A_inv_matrix_instance, &dct2d_output_matrix_instance, &dct2d_tmp_matrix_instance);
    arm_mat_mult_f32(&dct2d_tmp_matrix_instance, &dct2d_A_matrix_instance, &dct2d_input_matrix_instance);

    timer_stop(app_timer_ticks);

    printf("input_matrix_f32:\r\n");
    print_matrix_f32((float *)dct2d_input_matrix, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    printf("\r\n");
    printf(" * %lu us\r\n", (unsigned long)(app_timer_ticks / TIMER_TICKS_PER_US));
}


/* 使用cmsis-dsp的api进行计算, 定点数实现.
 * 使用定点数必要处理数据溢出的问题, 由于变换矩阵A的值不能改变,
 * 一般在实际应用中都是通过对输入数据缩放到一定程度, 确保在定点数的范围内能够完成计算, 才开始计算.
 * 之后再想方设法还原, 或者压根不还原, 直接使用相对关系表示信号.
 */
/*!
 * @brief Benchmark the 2D DCT with the CMSIS-DSP q31 (fixed-point) matrix API.
 *
 * Fixed-point arithmetic must deal with overflow. Because the values of the
 * transform matrix A cannot be changed, real applications usually scale the
 * input data down far enough that the whole computation stays inside the
 * fixed-point range. Afterwards they either restore the scale or skip that
 * step and work only with the relative magnitudes of the signal.
 *
 * Transforms timed (same as the float version):
 *   forward:  Y = A * X * A_inv   (result in dct2d_output_matrix_int32)
 *   backward: X = A_inv * Y * A   (result written back to dct2d_input_matrix_int32)
 * A and A_inv are converted from their float versions to q31 on every call,
 * outside the timed sections.
 */
void dct2d_arm_int32(void)
{
    printf("\r\n%s().\r\n", __func__);

    arm_matrix_instance_q31 dct2d_A_matrix_int32_instance =
    {
        .pData = (q31_t *)dct2d_A_matrix_int32,
        .numCols = DCT2D_MATRIX_N,
        .numRows = DCT2D_MATRIX_N
    };

    arm_matrix_instance_q31 dct2d_A_inv_matrix_int32_instance =
    {
        .pData = (q31_t *)dct2d_A_inv_matrix_int32,
        .numCols = DCT2D_MATRIX_N,
        .numRows = DCT2D_MATRIX_N
    };

    arm_matrix_instance_q31 dct2d_input_matrix_int32_instance =
    {
        .pData = (q31_t *)dct2d_input_matrix_int32,
        .numCols = DCT2D_MATRIX_N,
        .numRows = DCT2D_MATRIX_N
    };
    arm_matrix_instance_q31 dct2d_output_matrix_int32_instance =
    {
        .pData = (q31_t *)dct2d_output_matrix_int32,
        .numCols = DCT2D_MATRIX_N,
        .numRows = DCT2D_MATRIX_N
    };
    arm_matrix_instance_q31 dct2d_tmp_matrix_int32_instance =
    {
        .pData = (q31_t *)dct2d_tmp_matrix_int32,
        .numCols = DCT2D_MATRIX_N,
        .numRows = DCT2D_MATRIX_N
    };

    /* Generate the test input matrix X[i][j] = 10*i + j. */
    uint32_t i, j;
    for (i = 0u; i < DCT2D_MATRIX_N; i++)
    {
        for (j = 0u; j < DCT2D_MATRIX_N; j++)
        {
            /* The integer is written straight into q31 storage. A q31 element holding k
             * represents k * 2^-31, so the input is effectively scaled down by 2^31.
             * For this small test input that keeps intermediate results well inside
             * the q31 range. */
            dct2d_input_matrix_int32[i][j] = (i * 10 + j);
        }
    }

    printf("input_matrix_int32:\r\n");
    print_matrix_int32((int32_t *)dct2d_input_matrix_int32, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    printf("\r\n");

    /* Convert the float transform matrices A and A_inv to q31. */
    arm_float_to_q31((float *)dct2d_A_matrix, (q31_t *)dct2d_A_matrix_int32, DCT2D_MATRIX_N * DCT2D_MATRIX_N);
    arm_float_to_q31((float *)dct2d_A_inv_matrix, (q31_t *)dct2d_A_inv_matrix_int32, DCT2D_MATRIX_N * DCT2D_MATRIX_N);

    /* DCT forward transform: Y = A * X * A_inv. */
    memset(dct2d_output_matrix_int32, 0, sizeof(dct2d_output_matrix_int32));
    memset(dct2d_tmp_matrix_int32, 0, sizeof(dct2d_tmp_matrix_int32));

    timer_start();

    arm_mat_mult_q31(&dct2d_A_matrix_int32_instance, &dct2d_input_matrix_int32_instance, &dct2d_tmp_matrix_int32_instance);
    arm_mat_mult_q31(&dct2d_tmp_matrix_int32_instance, &dct2d_A_inv_matrix_int32_instance, &dct2d_output_matrix_int32_instance);

    timer_stop(app_timer_ticks);

    printf("output_matrix_int32:\r\n");
    print_matrix_int32((int32_t *)dct2d_output_matrix_int32, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    printf("\r\n");
    printf(" * %lu us\r\n", (unsigned long)(app_timer_ticks / TIMER_TICKS_PER_US));

    /* DCT backward (inverse) transform: X = A_inv * Y * A. */
    memset(dct2d_input_matrix_int32, 0, sizeof(dct2d_input_matrix_int32));
    memset(dct2d_tmp_matrix_int32, 0, sizeof(dct2d_tmp_matrix_int32));

    timer_start();

    arm_mat_mult_q31(&dct2d_A_inv_matrix_int32_instance, &dct2d_output_matrix_int32_instance, &dct2d_tmp_matrix_int32_instance);
    arm_mat_mult_q31(&dct2d_tmp_matrix_int32_instance, &dct2d_A_matrix_int32_instance, &dct2d_input_matrix_int32_instance);

    timer_stop(app_timer_ticks);

    printf("input_matrix_int32:\r\n");
    print_matrix_int32((int32_t *)dct2d_input_matrix_int32, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    printf("\r\n");
    printf(" * %lu us\r\n", (unsigned long)(app_timer_ticks / TIMER_TICKS_PER_US));
}

/* 使用PowerQuad, 浮点数计算. */
/*!
 * @brief Benchmark the 2D DCT on the PowerQuad matrix engine in float32.
 *
 * Uses the same test input X[i][j] = 10*i + j and the same transforms as the
 * CMSIS-DSP float version:
 *   forward:  Y = A * X * A_inv   (result in dct2d_output_matrix)
 *   backward: X = A_inv * Y * A   (result written back to dct2d_input_matrix)
 * The intermediate product of each transform is kept in the PowerQuad private
 * RAM (POWERQUAD_PRIVATE_RAM_BASE) rather than in a system-RAM buffer.
 */
void dct2d_powerquad_f32(void)
{
    printf("\r\n%s().\r\n", __func__);

    PQ_SetFormat(POWERQUAD, kPQ_CP_MTX, kPQ_Float);
    uint32_t length = POWERQUAD_MAKE_MATRIX_LEN(DCT2D_MATRIX_N, DCT2D_MATRIX_N, DCT2D_MATRIX_N);

    /* Generate the test input matrix X[i][j] = 10*i + j. */
    uint32_t i, j;
    for (i = 0u; i < DCT2D_MATRIX_N; i++)
    {
        for (j = 0u; j < DCT2D_MATRIX_N; j++)
        {
            dct2d_input_matrix[i][j] = i * 10.0f + j;
        }
    }

    printf("input_matrix_f32:\r\n");
    print_matrix_f32((float *)dct2d_input_matrix, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    printf("\r\n");

    /* DCT forward transform: Y = A * X * A_inv. */
    memset(dct2d_output_matrix, 0, sizeof(dct2d_output_matrix));

    timer_start();

    PQ_MatrixMultiplication(POWERQUAD, length, dct2d_A_matrix, dct2d_input_matrix, (void *)POWERQUAD_PRIVATE_RAM_BASE);
    PQ_WaitDone(POWERQUAD);
    PQ_MatrixMultiplication(POWERQUAD, length, (void *)POWERQUAD_PRIVATE_RAM_BASE, dct2d_A_inv_matrix, dct2d_output_matrix);
    PQ_WaitDone(POWERQUAD);

    timer_stop(app_timer_ticks);

    printf("output_matrix_f32:\r\n");
    print_matrix_f32((float *)dct2d_output_matrix, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    printf("\r\n");
    printf(" * %lu us\r\n", (unsigned long)(app_timer_ticks / TIMER_TICKS_PER_US));

    /* DCT backward (inverse) transform: X = A_inv * Y * A. */
    memset(dct2d_input_matrix, 0, sizeof(dct2d_input_matrix));

    timer_start();

    PQ_MatrixMultiplication(POWERQUAD, length, dct2d_A_inv_matrix, dct2d_output_matrix, (void *)POWERQUAD_PRIVATE_RAM_BASE);
    PQ_WaitDone(POWERQUAD);
    PQ_MatrixMultiplication(POWERQUAD, length, (void *)POWERQUAD_PRIVATE_RAM_BASE, dct2d_A_matrix, dct2d_input_matrix);
    PQ_WaitDone(POWERQUAD);

    timer_stop(app_timer_ticks);

    printf("input_matrix_f32:\r\n");
    print_matrix_f32((float *)dct2d_input_matrix, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    printf("\r\n");
    printf(" * %lu us\r\n", (unsigned long)(app_timer_ticks / TIMER_TICKS_PER_US));
}


/* 使用PowerQuad, 定点数计算.
 * a0版芯片的powerquad在接口上的数据格式转换上(定点数与浮点数转换)有硬件bug. 此函数仅在1b版芯片上运行有效(已经修复).
 */
/*!
 * @brief Benchmark the 2D DCT on PowerQuad with int32 input/output matrices.
 *
 * On A0 silicon, the PowerQuad interface has a hardware bug in its data-format
 * conversion (fixed-point <-> float). This function only produces correct
 * results on 1B silicon, where the bug is fixed.
 *
 * The input and output matrices are int32 (kPQ_32Bit). A, A_inv and the
 * intermediate product in PowerQuad private RAM stay float. Before each
 * multiplication, PQ_SetConfig() switches the per-operand formats so that
 * PowerQuad converts at its interface.
 *   forward:  Y = A * X * A_inv   (result in dct2d_output_matrix_int32)
 *   backward: X = A_inv * Y * A   (result written back to dct2d_input_matrix_int32)
 */
void dct2d_powerquad_int32(void)
{
    printf("\r\n%s().\r\n", __func__);

    /* Generate the test input matrix X[i][j] = 10*i + j. */
    uint32_t i, j;
    for (i = 0u; i < DCT2D_MATRIX_N; i++)
    {
        for (j = 0u; j < DCT2D_MATRIX_N; j++)
        {
            dct2d_input_matrix_int32[i][j] = i * 10u + j;
        }
    }

    printf("input_matrix_int32:\r\n");
    print_matrix_int32((int32_t *)dct2d_input_matrix_int32, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    printf("\r\n");

    /* Prepare the PowerQuad settings. Zero-initialise first so any field not
     * explicitly set below has a deterministic value. */
    uint32_t length = POWERQUAD_MAKE_MATRIX_LEN(DCT2D_MATRIX_N, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    pq_config_t pq_config = {0};
    pq_config.inputAPrescale = 0;
    pq_config.inputBPrescale = 0;
    pq_config.outputPrescale = 0;
    pq_config.tmpFormat = kPQ_Float;
    pq_config.tmpPrescale = 0;
    pq_config.machineFormat = kPQ_Float;
    pq_config.tmpBase = (void *)POWERQUAD_PRIVATE_RAM_BASE;

    /* DCT forward transform: Y = A * X * A_inv. */
    memset(dct2d_output_matrix_int32, 0, sizeof(dct2d_output_matrix_int32));

    timer_start();

    /* float A  x  int32 X  ->  float intermediate in private RAM. */
    pq_config.inputAFormat = kPQ_Float;
    pq_config.inputBFormat = kPQ_32Bit;
    pq_config.outputFormat = kPQ_Float;
    PQ_SetConfig(POWERQUAD, &pq_config);
    PQ_MatrixMultiplication(POWERQUAD, length, dct2d_A_matrix, dct2d_input_matrix_int32, (void *)POWERQUAD_PRIVATE_RAM_BASE);
    PQ_WaitDone(POWERQUAD);

    /* float intermediate  x  float A_inv  ->  int32 Y. */
    pq_config.inputAFormat = kPQ_Float;
    pq_config.inputBFormat = kPQ_Float;
    pq_config.outputFormat = kPQ_32Bit;
    PQ_SetConfig(POWERQUAD, &pq_config);
    PQ_MatrixMultiplication(POWERQUAD, length, (void *)POWERQUAD_PRIVATE_RAM_BASE, dct2d_A_inv_matrix, dct2d_output_matrix_int32);
    PQ_WaitDone(POWERQUAD);

    timer_stop(app_timer_ticks);

    printf("output_matrix_int32:\r\n");
    print_matrix_int32((int32_t *)dct2d_output_matrix_int32, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    printf("\r\n");
    printf(" * %lu us\r\n", (unsigned long)(app_timer_ticks / TIMER_TICKS_PER_US));

    /* DCT backward (inverse) transform: X = A_inv * Y * A. */
    memset(dct2d_input_matrix_int32, 0, sizeof(dct2d_input_matrix_int32));

    timer_start();

    /* float A_inv  x  int32 Y  ->  float intermediate in private RAM. */
    pq_config.inputAFormat = kPQ_Float;
    pq_config.inputBFormat = kPQ_32Bit;
    pq_config.outputFormat = kPQ_Float;
    PQ_SetConfig(POWERQUAD, &pq_config);
    PQ_MatrixMultiplication(POWERQUAD, length, dct2d_A_inv_matrix, dct2d_output_matrix_int32, (void *)POWERQUAD_PRIVATE_RAM_BASE);
    PQ_WaitDone(POWERQUAD);

    /* float intermediate  x  float A  ->  int32 X. */
    pq_config.inputAFormat = kPQ_Float;
    pq_config.inputBFormat = kPQ_Float;
    pq_config.outputFormat = kPQ_32Bit;
    PQ_SetConfig(POWERQUAD, &pq_config);
    PQ_MatrixMultiplication(POWERQUAD, length, (void *)POWERQUAD_PRIVATE_RAM_BASE, dct2d_A_matrix, dct2d_input_matrix_int32);
    PQ_WaitDone(POWERQUAD);

    timer_stop(app_timer_ticks);

    printf("input_matrix_int32:\r\n");
    print_matrix_int32((int32_t *)dct2d_input_matrix_int32, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    printf("\r\n");
    printf(" * %lu us\r\n", (unsigned long)(app_timer_ticks / TIMER_TICKS_PER_US));
}

/*!
 * @brief Reciprocal of @p input, evaluated by PowerQuad via PQ_InvF32().
 *
 * @param input  Value to invert.
 * @return       Result written by PQ_InvF32().
 */
float powerquad_inv(float input)
{
    float operand = input;
    float result = 0.0f;

    PQ_InvF32(&operand, &result);

    return result;
}

/*!
 * @brief Cosine of @p input, evaluated by PowerQuad via PQ_CosF32().
 *
 * @param input  Angle. NOTE(review): callers in this file pass radians
 *               (multiples of PI) - confirm that PQ_CosF32 expects radians.
 * @return       Result written by PQ_CosF32().
 */
float powerquad_cos(float input)
{
    float operand = input;
    float result = 0.0f;

    PQ_CosF32(&operand, &result);

    return result;
}

/*!
 * @brief Square root of @p input, evaluated by PowerQuad via PQ_SqrtF32().
 *
 * @param input  Radicand.
 * @return       Result written by PQ_SqrtF32().
 */
float powerquad_sqrt(float input)
{
    float operand = input;
    float result = 0.0f;

    PQ_SqrtF32(&operand, &result);

    return result;
}

/*!
 * @brief Print a row-major float matrix, one row per line.
 *
 * Compiles to a no-op unless APP_ENABLE_PRINT_RESULT is non-zero.
 *
 * @param matrix   Row-major storage of at least row_num * col_num elements; only read.
 * @param row_num  Number of rows.
 * @param col_num  Number of columns.
 */
void print_matrix_f32(float *matrix, uint32_t row_num, uint32_t col_num)
{
#if APP_ENABLE_PRINT_RESULT
    const float *p = matrix;
    for (uint32_t i = 0u; i < row_num; i++)
    {
        for (uint32_t j = 0u; j < col_num; j++)
        {
            /* Explicit promotion: printf's %f consumes a double. */
            printf("%10.5f ", (double)*p++);
        }
        printf("\r\n");
    }
#else
    /* Printing disabled: silence unused-parameter warnings. */
    (void)matrix;
    (void)row_num;
    (void)col_num;
#endif /* APP_ENABLE_PRINT_RESULT */
}

/*!
 * @brief Print a row-major int32 matrix, one row per line.
 *
 * Compiles to a no-op unless APP_ENABLE_PRINT_RESULT is non-zero.
 *
 * @param matrix   Row-major storage of at least row_num * col_num elements; only read.
 * @param row_num  Number of rows.
 * @param col_num  Number of columns.
 */
void print_matrix_int32(int32_t *matrix, uint32_t row_num, uint32_t col_num)
{
#if APP_ENABLE_PRINT_RESULT
    const int32_t *p = matrix;
    for (uint32_t i = 0u; i < row_num; i++)
    {
        for (uint32_t j = 0u; j < col_num; j++)
        {
            /* int32_t is 'long' on some ARM toolchains; cast so %ld always matches. */
            printf("%10ld ", (long)*p++);
        }
        printf("\r\n");
    }
#else
    /* Printing disabled: silence unused-parameter warnings. */
    (void)matrix;
    (void)row_num;
    (void)col_num;
#endif /* APP_ENABLE_PRINT_RESULT */
}

/* EOF. */

